#!/bin/bash

# Install conda's shell integration into this non-interactive shell so that
# `conda activate` is available, then enter the project environment.
# Abort the whole script if the environment cannot be activated.
eval "$(conda shell.bash hook)"
if ! conda activate reil; then
  exit 1
fi

# Trace every following command (with expanded arguments) on stderr.
set -x


# Directory holding train.parquet, test.parquet and test_env.parquet.
DATA_DIR="./data/sokoban_one_horizon_large_envs"
# Other candidate model snapshots, kept for reference (not used below):
# BASE_MODEL="./models/rlft/models--Qwen--Qwen2.5-3B-Instruct/snapshots/aa8e72537993ba99e69dfaafa59ed015b17504d1"
# BASE_MODEL="./models/rlft/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775"
# BASE_MODEL="./models/rlft/models--Qwen--Qwen2.5-1.5B-Instruct/snapshots/989aa7980e4cf806f80c7fef2b1adb7bc71aa306"
# BASE_MODEL="./models/rlft/models--Qwen--Qwen2.5-3B/snapshots/3aab1f1954e9cc14eb9509a215f9e5ca08227a9b"

# Checkpoint to evaluate; passed as both the actor/rollout and critic model path.
INFERENCE_MODEL="./checkpoints/ds543/sft/sokoban-1.5b-sft-qwen-2.5-1.5b-base-full-sft/global_step_225"
# INFERENCE_MODEL="./checkpoints/REIL/exp-1.5b-0.005beta-logic-with-kl-0.001-1024-ctx-one-horizon/huggingface"

BETA=0.005           # actor entropy coefficient (actor.entropy_coeff)
KL_COEF=0.001        # KL coefficient (algorithm.kl_ctrl.kl_coef)
CONTEXT_LENGTH=1024  # used as data.max_response_length

# Run name. The checkpoint path is flattened into one slash-free token:
# embedding the raw path would put "./" and "/" separators into the
# experiment name (e.g. "inference-1.5b-./checkpoints/ds543/...").
MODEL_TAG="${INFERENCE_MODEL#./}"
MODEL_TAG="${MODEL_TAG#checkpoints/}"
EXPERIMENT_NAME="inference-1.5b-${MODEL_TAG//\//-}"

ROLLOUT_TP_SIZE=1    # rollout tensor-parallel size
N_GPUS=1             # GPUs per node given to the trainer
# Pin the job to physical GPU 3 (overrides any value inherited from the caller).
export CUDA_VISIBLE_DEVICES=3
# Attention backend selection read by vLLM — presumably needed for this
# hardware/model combination; confirm before changing.
export VLLM_ATTENTION_BACKEND=XFORMERS

# Validation-only run (+trainer.val_only=True) of reil.trainer.main_ppo on
# INFERENCE_MODEL; all stdout/stderr is mirrored into verl_demo_1.5b.log.
#
# Without pipefail the script's exit status would be tee's (practically always
# 0), hiding a failed python3 run from callers/CI.
set -o pipefail

# Hydra overrides, one per element and in the original order. Every expansion
# is quoted so paths containing spaces/glob characters are passed intact.
ppo_overrides=(
  # Data
  "data.train_files=${DATA_DIR}/train.parquet"
  "data.val_files=${DATA_DIR}/test.parquet"
  "data.val_env_files=${DATA_DIR}/test_env.parquet"
  "data.train_batch_size=256"
  "data.val_batch_size=32"
  "max_turns=10"
  "+trainer.val_only=True"
  "env.name=sokoban_reil"
  "data.max_prompt_length=1000"
  "data.max_response_length=${CONTEXT_LENGTH}"
  # Actor / rollout / reference
  "actor_rollout_ref.model.path=${INFERENCE_MODEL}"
  "actor_rollout_ref.model.use_remove_padding=True"
  "actor_rollout_ref.actor.use_dynamic_bsz=True"
  "actor_rollout_ref.actor.optim.lr=1e-6"
  "actor_rollout_ref.actor.ppo_mini_batch_size=64"
  "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16"
  "actor_rollout_ref.actor.entropy_coeff=${BETA}"
  "actor_rollout_ref.rollout.log_prob_micro_batch_size=16"
  "actor_rollout_ref.rollout.tensor_model_parallel_size=${ROLLOUT_TP_SIZE}"
  "actor_rollout_ref.rollout.gpu_memory_utilization=0.7"
  "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16"
  # Critic
  "critic.optim.lr=1e-5"
  "critic.model.path=${INFERENCE_MODEL}"
  "critic.ppo_micro_batch_size=16"
  # Algorithm
  "algorithm.use_kl_in_reward=True"
  "algorithm.kl_ctrl.kl_coef=${KL_COEF}"
  # Trainer. The logger list must be quoted: unquoted, "[...]" is a bash
  # glob bracket expression. Bash quote removal already turned the old
  # unquoted ['wandb'] into the argument [wandb]; pass exactly that.
  "trainer.logger=[wandb]"
  "trainer.val_before_train=True"
  "trainer.default_hdfs_dir=null"
  "trainer.n_gpus_per_node=${N_GPUS}"
  "trainer.nnodes=1"
  "trainer.save_freq=1000"
  "trainer.test_freq=25"
  "trainer.project_name=REIL"
  "trainer.resume_mode=auto"
  "trainer.log_val_generations=4"
  "trainer.experiment_name=${EXPERIMENT_NAME}"
  "trainer.total_epochs=500"
  "trainer.is_rl_validation=True"
  # Reward
  "reward_model.reward_manager=complete"
  "custom_reward_function.path=./reil/utils/reward_score/sokoban.py"
  "custom_reward_function.name=compute_score_with_action_sequence"
)

python3 -m reil.trainer.main_ppo "${ppo_overrides[@]}" 2>&1 | tee verl_demo_1.5b.log